/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *       this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *       contributors may be used to endorse or promote products derived from
 *       this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for thread-level GEMM compiled at runtime with NVRTC.
*/

#pragma once

#include <iostream>
#include <string>
#include <vector>

#include "cutlass/gemm/thread/mma.h"
#include "../kernel/thread/testbed_kernel.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include <cuda.h>
#include <nvrtc.h>
#include "../cutlass/nvrtc/environment.h"
#include <assert.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace nvrtc {
namespace thread {

/// Testbed that compiles a thread-level GEMM kernel at runtime with NVRTC,
/// launches it, and verifies the result against a host reference GEMM
template <
        /// Size of the Gemm problem - concept: gemm::GemmShape<>
        typename Shape,
        /// Data type of A elements
        typename ElementA,
        /// Layout of A matrix (concept: MatrixLayout)
        typename LayoutA,
        /// Data type of B elements
        typename ElementB,
        /// Layout of B matrix (concept: MatrixLayout)
        typename LayoutB,
        /// Element type of C matrix
        typename ElementC,
        /// Layout of C matrix (concept: MatrixLayout)
        typename LayoutC>
struct Testbed {
    /// Thread-level matrix multiply-accumulate operator
    using Mma = cutlass::gemm::thread::Mma<Shape, ElementA, LayoutA, ElementB,
                                           LayoutB, ElementC, LayoutC>;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    Testbed() {
        tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK));
        tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
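        // The reference output is only compared on the host, so skip the
        // device allocation for it.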
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
    }

    static inline bool check_nvrtc_error(nvrtcResult error) {
        if (error != NVRTC_SUCCESS) {
            std::cerr << "NVRTC error: " << nvrtcGetErrorString(error)
                      << std::endl;
            return false;
        }
        return true;
    }

    /// Runs the test
    bool run(std::string const& gemm_traits) {
        //
        // Initialize host tensors, then mirror the operands to device memory
        //

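        // Fill A and B with deterministic sequential values so each element is
        // distinct and any mismatch against the reference is easy to localize.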
        cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                      tensor_A.capacity());

        cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                      tensor_B.capacity(),
                                                      ElementB(1), ElementB(2));

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();

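        // The disabled branch below launches the statically compiled kernel
        // directly; the active branch compiles the same kernel at runtime with
        // NVRTC instead.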
#if 0
    // launch kernel
    cutlass::gemm::kernel::testbed_kernel<Mma><<< dim3(1, 1), dim3(1, 1, 1) >>>(
      tensor_D_computed.device_data(),
      tensor_A.device_data(),
      tensor_B.device_data(),
      tensor_C.device_data());

#else
        // Build the NVRTC program that instantiates the kernel template
        nvrtcResult result_nvrtc;
        nvrtcProgram program;
        static char const* src =
                "#include \"cutlass/gemm/thread/mma.h\"\n"
                "#include \"cutlass/gemm/gemm.h\"\n"
                "#include \"cutlass/layout/matrix.h\"\n"
                "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n";

        std::string type_name;
#if 0
    // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names.
    // As an alternative we might want to implement to_string<GemmTraits>() to build the traits string.
    nvrtcGetTypeName<typename GemmTraits_>(&type_name);
#else
        type_name = gemm_traits;
#endif

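        // cutlass::nvrtc::environment provides the CUTLASS headers as
        // in-memory sources so NVRTC can resolve the #includes above without
        // touching the filesystem.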
        result_nvrtc = nvrtcCreateProgram(
                &program, src, NULL, (int)cutlass::nvrtc::kCutlassHeaderCount,
                cutlass::nvrtc::kCutlassHeaders,
                cutlass::nvrtc::kCutlassHeaderNames);
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

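        // Register the instantiation as a name expression so its mangled
        // (lowered) name can be queried after compilation.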
        std::string gemm_kernel_instantiation =
                "test::nvrtc::kernel::thread::testbed_kernel< " + type_name +
                " >";
        result_nvrtc = nvrtcAddNameExpression(
                program, gemm_kernel_instantiation.c_str());
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

        // Note: the include path below is specific to the local CUDA
        // installation.
        const char* opts[] = {"--gpu-architecture=compute_75", "--std=c++11",
                              "--include-path=/usr/local/cuda-10.1/include"};

        result_nvrtc = nvrtcCompileProgram(
                program, int(sizeof(opts) / sizeof(opts[0])), opts);
        if (result_nvrtc != NVRTC_SUCCESS) {
            size_t logSize;
            nvrtcGetProgramLogSize(program, &logSize);
            std::vector<char> log(logSize);
            nvrtcGetProgramLog(program, log.data());
            std::cout << "Compile log:" << std::endl << log.data() << std::endl;
        }
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

        // The lowered name is the name of the template instantiation in the
        // generated PTX code.
        char const* gemm_kernel_lowered_name;
        result_nvrtc = nvrtcGetLoweredName(program,
                                           gemm_kernel_instantiation.c_str(),
                                           &gemm_kernel_lowered_name);
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

        // Query the size of the generated PTX so that we can allocate storage
        // and retrieve it afterwards.
        size_t ptx_size;
        result_nvrtc = nvrtcGetPTXSize(program, &ptx_size);
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

        std::vector<char> ptx(ptx_size);
        result_nvrtc = nvrtcGetPTX(program, ptx.data());
        if (!check_nvrtc_error(result_nvrtc)) {
            assert(0);
        }

        // Note: the lowered name obtained above points into memory owned by
        // the NVRTC program, so the program must stay alive until
        // cuModuleGetFunction below has consumed it.

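        // JIT-compile the PTX into a module loadable by the current CUDA
        // context.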
        CUmodule module;
        CUresult result_cuda;
        result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0);
        if (result_cuda != CUDA_SUCCESS) {
            assert(0);
        }

        CUfunction kernel;
        result_cuda =
                cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name);
        if (result_cuda != CUDA_SUCCESS) {
            assert(0);
        }
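
        // The lowered name has now been consumed, so the NVRTC program can be
        // released.
        nvrtcDestroyProgram(&program);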

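        // cuLaunchKernel's kernelParams is an array of pointers to each
        // argument value, hence the address-of on the device pointers.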
        void* d_a = (void*)tensor_A.device_data();
        void* d_b = (void*)tensor_B.device_data();
        void* d_c = (void*)tensor_C.device_data();
        void* d_d = (void*)tensor_D_computed.device_data();
        void* args[] = {&d_d, &d_a, &d_b, &d_c};

        // Launch a single thread: grid (1,1,1), block (1,1,1), no dynamic
        // shared memory, default stream.
        result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0,
                                     0 /*cudaStreamDefault*/, args, 0);
        if (result_cuda != CUDA_SUCCESS) {
            assert(0);
        }
#endif

        // Block until the kernel finishes and surface any launch or runtime
        // errors.
        cudaError_t result = cudaDeviceSynchronize();

        if (result != cudaSuccess) {
            std::cout << "CUDA ERROR: " << cudaGetErrorString(result)
                      << std::endl;
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //

        // tensor_D_reference.fill(tensor_C.host_view());

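        // Compute D_reference = alpha * (A * B) + beta * C on the host, with
        // alpha = 1 and beta = 0.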
        cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                       ElementC, LayoutC, ElementC, ElementC>
                reference_gemm;

        reference_gemm({Shape::kM, Shape::kN, Shape::kK}, ElementC(1),
                       tensor_A.host_ref(), tensor_B.host_ref(), ElementC(0),
                       tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        if (!passed)
            std::cout << "A:\n"
                      << tensor_A.host_view() << "\n\n"
                      << "B:\n"
                      << tensor_B.host_view() << "\n\n"
                      << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << std::endl;

        std::cout << "passed " << passed << std::endl;

        return passed;
    }
};
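
/// Example usage (a minimal sketch; the shape, element types, layouts, and the
/// traits string below are illustrative and must be kept consistent with one
/// another):
///
///   Testbed<cutlass::gemm::GemmShape<8, 8, 1>,
///           float, cutlass::layout::ColumnMajor,
///           float, cutlass::layout::ColumnMajor,
///           float, cutlass::layout::ColumnMajor> testbed;
///
///   bool passed = testbed.run(
///       "cutlass::gemm::thread::Mma<cutlass::gemm::GemmShape<8, 8, 1>, "
///       "float, cutlass::layout::ColumnMajor, "
///       "float, cutlass::layout::ColumnMajor, "
///       "float, cutlass::layout::ColumnMajor>");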

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace thread
}  // namespace nvrtc
}  // namespace test
